Skip to content

fix: validate inputs and preserve numpy return type in calc_quantile_loss#839

Merged
WenjieDu merged 2 commits into
WenjieDu:devfrom
shaun0927:fix/quantile-loss-validation-and-dtype
Apr 25, 2026
Merged

fix: validate inputs and preserve numpy return type in calc_quantile_loss#839
WenjieDu merged 2 commits into
WenjieDu:devfrom
shaun0927:fix/quantile-loss-validation-and-dtype

Conversation

@shaun0927
Copy link
Copy Markdown

Description

Two small, related fixes to pypots/nn/functional/error.py that make
calc_quantile_loss behave like its siblings in the same module.

1. Missing _check_inputs call.
Every other calc_* in the file (calc_mae, calc_mse, calc_rmse,
calc_mre, calc_quantile_crps, calc_quantile_crps_sum) starts with
lib = _check_inputs(...) so NaN / shape / dtype problems surface as a
clear AssertionError. calc_quantile_loss was the one exception and
silently returned tensor(nan) on NaN input, e.g.

calc_mae(p_with_nan, t, m)              # AssertionError, good
calc_quantile_loss(p_with_nan, t, 0.5, m)  # tensor(nan), bad

Adding the same guard makes the module's validation contract uniform.

2. numpy-in / numpy-out parity.
The function signature returns Union[float, torch.Tensor], matching the
sibling metrics which preserve the caller's type (numpy in → numpy out).
After #822 added numpy support, calc_quantile_loss always converted
numpy inputs to torch tensors and returned a torch.Tensor, breaking
that contract:

type(calc_mae(np_a, np_b))             # numpy.float32
type(calc_quantile_loss(np_a, np_b, 0.5, np_m))  # torch.Tensor  <-- regression

Restoring the numpy return on numpy inputs keeps downstream code that
expected a plain numpy scalar working.

Changes

  • pypots/nn/functional/error.pycalc_quantile_loss:
    • call _check_inputs(predictions, targets, eval_points) first;
    • capture numpy_in = isinstance(predictions, np.ndarray) before the
      conversion block;
    • when numpy_in, return quantile_loss.detach().cpu().numpy() so the
      declared Union[float, torch.Tensor] contract holds.

The change is confined to one function and does not touch the
_check_inputs signature, so it does not conflict with the in-flight #821.

Testing

  • NaN input now raises AssertionError: predictions mustn't contain NaN values, matching calc_mae / calc_mse / …
  • Shape mismatch raises AssertionError: shape of predictions and targets must match … instead of relying on downstream broadcast errors.
  • numpy in → numpy out is preserved; torch in → torch out is unchanged.
  • calc_quantile_crps still calls calc_quantile_loss internally on
    torch tensors; that path is numerically unchanged.

shaun0927 added 2 commits April 17, 2026 13:08
…loss

Two related consistency fixes for pypots/nn/functional/error.py:

1. calc_quantile_loss was the only calc_* function that did not call
   _check_inputs(). NaN or shape-mismatched inputs therefore produced
   silently wrong numeric results instead of the clear AssertionError
   that calc_mae/calc_mse/calc_rmse/calc_mre raise. This adds the same
   guard so all error metrics share a single validation contract.

2. After numpy support was introduced in WenjieDu#822, numpy inputs were
   always returned as a torch.Tensor, breaking the existing
   Union[float, torch.Tensor] contract that sibling metrics honor
   (numpy in -> numpy out). The function now converts back to a numpy
   scalar when the caller passed numpy arrays.

Verified with the sibling metrics on both numpy and torch paths.
…s_sum working

The initial fix passed _check_inputs with default check_shape=True, which
rejects the intentional broadcasting in the calc_quantile_crps_sum code
path (q_pred has shape (B,) while targets has shape (B, T)). Passing
check_shape=False keeps the NaN/type guards that motivated this change
while allowing both internal callers (calc_quantile_crps and
calc_quantile_crps_sum) to keep broadcasting as they did before.

_check_inputs still validates mask.shape == targets.shape, so the mask
contract is unchanged.
@sonarqubecloud
Copy link
Copy Markdown

@WenjieDu WenjieDu changed the base branch from main to dev April 25, 2026 18:29
@WenjieDu WenjieDu merged commit c7da52b into WenjieDu:dev Apr 25, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants